import csv
import json
import os
from pathlib import Path

from openai import OpenAI
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import pipeline
from huggingface_hub import snapshot_download
import numpy as np
from gensim.models import KeyedVectors
import sklearn
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier
import joblib

from reasoner.reasoner import Reasoner

# OpenAI-compatible client pointed at the DeepSeek endpoint. The key is read
# from the DEEPSEEK_API_KEY environment variable so secrets stay out of source
# control; the old placeholder is kept as a fallback for compatibility.
client = OpenAI(
    api_key=os.environ.get("DEEPSEEK_API_KEY", "your_api_key"),
    base_url="https://api.deepseek.com/v1",
)


def load_knowledge_base(csv_path):
    """Load the ICD-9 symptom whitelist from a CSV file.

    The CSV must have ``symptom`` and ``code`` columns. Each row with a
    non-blank code becomes one ``"<symptom> -> <code>"`` line. Whitespace is
    trimmed and dots are removed from the code (``"780.6"`` -> ``"7806"``).
    Rows whose code is blank or missing are skipped, so an empty string never
    enters the whitelist.

    Args:
        csv_path: Path to the CSV file (read as UTF-8; a BOM is tolerated).

    Returns:
        The newline-joined knowledge lines. Returns ``""`` if the file is
        missing, lacks the required columns, or cannot be read. The error is
        printed rather than raised.
    """
    knowledge_lines = []

    try:
        # newline='' is what the csv module expects for correct quoting handling.
        with open(csv_path, mode='r', encoding='utf-8-sig', newline='') as file:
            reader = csv.DictReader(file)

            # fieldnames is None for an empty file; treat that as "no columns".
            if not {'symptom', 'code'}.issubset(reader.fieldnames or ()):
                raise ValueError("The CSV file must contain columns 'symptom' and 'code'")

            for row in reader:
                # Short rows yield None for missing columns; treat them as blank
                # instead of letting one bad row discard the whole file.
                symptom = (row['symptom'] or '').strip()
                code = (row['code'] or '').strip().replace('.', '')
                if not code:
                    continue
                knowledge_lines.append(f"{symptom} -> {code}")

        return "\n".join(knowledge_lines)

    except FileNotFoundError:
        print(f"ERROR：File {csv_path} not found.")
        return ""
    except Exception as e:
        print(f"Failed to read CSV file:{str(e)}")
        return ""

# BioBERT-based medical entity recognition component.
class BioBertEntityExtractor:
    """Extract disease and drug mentions from free text with a BioBERT NER model.

    Wraps a Hugging Face ``"ner"`` pipeline (``aggregation_strategy="simple"``)
    built from a locally stored token-classification model. Its output is
    normalized to the entity groups ``"Disease"`` and ``"Drug"``.
    """

    def __init__(self, device=None, model_path="./biobert_model"):
        """Load the model and tokenizer and build the NER pipeline.

        Args:
            device: A ``torch.device`` or anything ``torch.device()`` accepts
                (e.g. ``"cpu"``, ``"cuda:1"``). Defaults to CUDA when
                available, otherwise CPU.
            model_path: Directory (or model id) passed to ``from_pretrained``
                for both the model and the tokenizer.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        # Normalize strings so self.device.type / .index are always available.
        if not isinstance(device, torch.device):
            device = torch.device(device)
        self.device = device

        # NOTE(review): presumably a checkpoint fine-tuned on BC5CDR (per the
        # original comment) — confirm against the contents of model_path.
        self.model = AutoModelForTokenClassification.from_pretrained(model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

        # BC5CDR-style label ids. Not referenced elsewhere in this class; kept
        # for any external code that may read it.
        self.label_map = {
            0: "O",
            1: "B-Disease",
            2: "I-Disease",
            3: "B-Chemical",
            4: "I-Chemical"
        }

        self.ner_pipeline = pipeline(
            "ner",
            model=self.model,
            tokenizer=self.tokenizer,
            device=self._pipeline_device(),
            aggregation_strategy="simple"
        )

    def _pipeline_device(self):
        """Map ``self.device`` to the integer device id used for the pipeline.

        CUDA devices keep their index (``cuda`` -> 0, ``cuda:1`` -> 1). Every
        other device type maps to -1 (CPU).
        """
        if self.device.type != "cuda":
            return -1
        return self.device.index if self.device.index is not None else 0

    def extract_entities(self, text):
        """Run NER on *text* and keep only disease and drug entities.

        Returns:
            A list of pipeline entity dicts, e.g.
            ``[{"entity_group": "Disease", "word": "pneumonia", "start": 20, "end": 29, ...}]``.
            ``entity_group`` is ``"Disease"`` or ``"Drug"`` (Chemical labels are
            renamed to Drug). ``word`` is cleaned and lower-cased. Empty or
            blank input returns ``[]``. Any pipeline error is printed and
            ``[]`` is returned.
        """
        try:
            if not text or not text.strip():
                return []

            results = self.ner_pipeline(text)

            filtered = []
            for entity in results:
                # Merge chemical labels into "Drug" and normalize disease labels.
                if "Chemical" in entity["entity_group"]:
                    entity["entity_group"] = "Drug"
                elif "Disease" in entity["entity_group"]:
                    entity["entity_group"] = "Disease"

                # Drop non-medical entities.
                if entity["entity_group"] in ("Disease", "Drug"):
                    entity["word"] = self._clean_entity(entity["word"])
                    filtered.append(entity)

            return filtered

        except Exception as e:
            print(f"Entity extraction failed: {str(e)}")
            return []

    def _clean_entity(self, text):
        """Strip WordPiece ``##`` markers, turn ``_`` into spaces, trim, lower-case."""
        text = text.replace("##", "").replace("_", " ").strip()
        return text.lower()



# Module-level ICD-9 whitelist ("symptom -> code" lines), loaded once at import
# time and interpolated into the LLM prompts built below.
csv_path = "./ICD9_symptom_mapping.csv"
knowledge_base = load_knowledge_base(csv_path=csv_path)

def get_promote(text, entities):
    """Build the ICD-9 coding prompt for the first LLM call.

    The prompt embeds the module-level ``knowledge_base`` whitelist
    (``symptom -> code`` lines) twice: once under "Role Setting" and once
    inside task requirement 2. It also includes the raw user text and the
    BioBERT entities, and asks for a JSON object with a ``diagnoses`` list.

    Args:
        text: The user's natural-language description.
        entities: Entities from ``BioBertEntityExtractor.extract_entities``;
            interpolated via their ``str()`` representation.

    Returns:
        The prompt string.
    """
    # NOTE(review): the few-shot example below maps a numbness complaint to
    # "Shigella boydii" and "Poisoning by other diuretics", which looks
    # unrelated to the input — confirm the example is intentional, since the
    # model will tend to imitate it.
    # NOTE(review): the "Output Format" JSON contains '#' comments, which are
    # not valid JSON; the model is expected to omit them in its reply.
    return f"""
# Role Setting
You are a medical coding assistant that strictly follows a whitelist and can only use the following preset codes:

{knowledge_base}

# Input Format
Input as a series of users' natural language:

{text}

# Medical Knowledge Context
You have identified the following clinically relevant entities:

{entities}

# Task Requirements
1. Identify the severity of these entities(generally in English) from the input. Ignore the parts of the patient's language that are unrelated to the symptoms. If the user's input is unrelated to the medical field, it is possible to have a conversation with the user without the set personality.
2. Every entity must match only the standard terms in this strict list {knowledge_base}. This entity can match the similar meaning of an element in a strict list, or it can match the symptoms generated by an element in a strict list.
3. Sort the matches by severity(descending order).
4. Return in JSON format.


# Output Format
{{
    "diagnoses": [
        {{
            "standard_term": "Standardized medical term 1",
            "match_status": "Matched/Unknown",  # Matching status
            "severity": 10,            # Severity, which is 0 if Matching status is Unknown
            "icd9_code": ""                    # Leave empty if unmatched
        }},
        {{
            "standard_term": "Standardized medical term 2",
            "match_status": "Matched/Unknown",
            "severity": 4,
            "icd9_code": ""
        }}
    ]
}}

# Example
Input: My hands and feet feel constantly numb, like I’m wearing invisible gloves and socks all the time.
Output:
{{
    "diagnoses": [
        {{"standard_term": "Shigella boydii", "match_status": "Matched", "severity": 2, "icd9_code": "42"}},
        {{"standard_term": "Poisoning by other diuretics", "match_status": "Matched", "severity": 1, "icd9_code": "9744"}}
    ]
}}
"""

def get_promote2(pre_res, intermediate):
    """Build the patient-facing explanation prompt for the second LLM call.

    Args:
        pre_res: Disease-level reasoning results, interpolated via ``str()``.
            The prompt describes them as ``(icd9_code, TruthValue)`` tuples.
        intermediate: Gene-level intermediate results, interpolated via
            ``str()``. The prompt describes them as ``(gene_id, TruthValue)``
            tuples.

    Returns:
        The prompt string. It also embeds the module-level ``knowledge_base``
        so the model can map ICD-9 codes back to terms.
    """
    return f"""
# Role Setting
You are a senior medical translator who needs to translate professional biomarker information into natural language explanations that are easy for patients to understand.
Previously, through analysis, you obtained some possible symptoms from the patient's description:
{pre_res}
The format of this data is:
[('4254', <TruthValue: %1.00;0.99% (k=1)>),
 ('42789', <TruthValue: %1.00;0.99% (k=1)>),
 ('3320', <TruthValue: %1.00;0.98% (k=1)>),
 ('3591', <TruthValue: %1.00;0.96% (k=1)>),
 ('4271', <TruthValue: %1.00;0.96% (k=1)>)]
In this list, the first element of each tuple is the ICD9 number of the disease, and the mapping rule is based on {knowledge_base}. The second one is the true value. The format of the truth value is% f; c%， The higher the value of c, the greater the degree of certainty.

Now, you have obtained an intermediate result from the backend, which is a list of genes and proteins related to these diseases:
{intermediate}
The format of this data is:
[('GENE:32600', <TruthValue: %1.00;0.90% (k=1)>),
 ('GENE:29744', <TruthValue: %1.00;0.90% (k=1)>),
 ('GENE:32213', <TruthValue: %1.00;0.90% (k=1)>),
 ('GENE:33672', <TruthValue: %1.00;0.90% (k=1)>),
 ('GENE:29386', <TruthValue: %1.00;0.90% (k=1)>),
 ('GENE:29404', <TruthValue: %1.00;0.90% (k=1)>),
 ('GENE:30994', <TruthValue: %1.00;0.90% (k=1)>),
 ('GENE:32523', <TruthValue: %1.00;0.90% (k=1)>),
 ('GENE:32532', <TruthValue: %1.00;0.90% (k=1)>),
 ('GENE:24563', <TruthValue: %1.00;0.90% (k=1)>)]
In this list, the first element of each tuple is the gene number, and the second element is also the truth value, in the format of% f; c%， The higher the value of c, the greater the degree of certainty.

With these two pieces of information, you need to:
1. Explain to the patient what diseases they may have. 
2. If the user further inquires about the details of the illness or the cause of the illness, then explain to the patient why these symptoms occur, which genes are related to them in the body, and name up to five related genes. If there is no further inquiry, the user may not be presented with genetic related reasoning.
3. Replace professional medical terminology with more common vocabulary.
4. Finally, provide 1-3 specific suggestions and a disclaimer. It is also necessary to consult a professional doctor at the hospital.
"""


def _strip_json_fence(content):
    """Return *content* without a surrounding Markdown code fence.

    Chat models sometimes wrap JSON in a fenced block (a line of three
    backticks, optionally followed by ``json``) even when asked for raw JSON,
    and ``json.loads`` rejects the wrapper. ``None`` is treated as ``""``.
    """
    stripped = (content or "").strip()
    if not stripped.startswith("```"):
        return stripped
    # Drop the opening fence line, then a trailing closing fence if present.
    _, _, body = stripped.partition("\n")
    body = body.rstrip()
    if body.endswith("```"):
        body = body[:-3]
    return body.strip()


def extract_icd9_strict(text, entities):
    """Map the user's description to whitelisted ICD-9 codes via the LLM.

    Args:
        text: The user's natural-language description.
        entities: BioBERT entities to include in the prompt.

    Returns:
        The parsed JSON object, expected to hold a ``diagnoses`` list. Each
        diagnosis's ``icd9_code`` is normalized (stringified, trimmed, dots
        removed, matching ``load_knowledge_base``). If the normalized code is
        not in the module-level ``knowledge_base`` whitelist, the diagnosis is
        rewritten to ``match_status="Unknown"`` and ``icd9_code=""``. On any
        request, parsing or validation error, the error is printed and
        ``{"diagnoses": []}`` is returned.
    """
    # The suffix repeats the text as "Input: ... / Output:" (written in Chinese).
    full_prompt = get_promote(text, entities) + f"\n输入：{text}\n输出："

    try:
        response = client.chat.completions.create(
            model="deepseek-reasoner",
            messages=[{"role": "user", "content": full_prompt}],
            # NOTE(review): DeepSeek documents that deepseek-reasoner ignores
            # temperature (and may not honour JSON mode), so output is not
            # guaranteed deterministic or unfenced — confirm.
            temperature=0.0,
            response_format={"type": "json_object"}
        )
        result = json.loads(_strip_json_fence(response.choices[0].message.content))

        # Whitelist = the code after the last " -> " on each knowledge line.
        valid_codes = set()
        for line in knowledge_base.split("\n"):
            _, sep, code = line.rpartition(" -> ")
            if sep and code.strip():
                valid_codes.add(code.strip())

        for item in result["diagnoses"]:
            raw_code = item.get("icd9_code")
            # Normalize the same way the loader does, so "780.6" or 7806 match "7806".
            code = "" if raw_code is None else str(raw_code).strip().replace(".", "")
            if code in valid_codes:
                item["icd9_code"] = code
            else:
                item.update({"match_status": "Unknown", "icd9_code": ""})
        return result
    except Exception as e:
        print(f"Error: {str(e)}")
        return {"diagnoses": []}

def get_explaination(text, entities):
    """Ask the LLM for a patient-friendly explanation of the reasoning output.

    Despite the parameter names (kept for compatibility), the callers in this
    file pass the sorted disease results as *text* and the sorted
    intermediate (gene) results as *entities*. Both are forwarded to
    ``get_promote2``.

    Returns:
        The stripped explanation text. If the request or response handling
        raises, returns an ``"Error generating explanation..."`` string instead.
    """
    prompt = get_promote2(text, entities)
    request = {
        "model": "deepseek-reasoner",
        "messages": [{"role": "user", "content": prompt}],
        # A little sampling freedom (0.5-0.6) reads more naturally for patients.
        "temperature": 0.5,
        "max_tokens": 1024,
        "stop": ["---"],
    }

    try:
        completion = client.chat.completions.create(**request)
        explanation = completion.choices[0].message.content
        return explanation.strip()
    except Exception as exc:
        return f"Error generating explanation：{str(exc)}"
    
def process_protein_list(protein_list):
    """Normalize predicted node identifiers to the ``PROTEIN:<id>`` form.

    Items that are already ``PROTEIN:``-prefixed strings are kept unchanged.
    Anything ``int()`` accepts is normalized through it (so ``"007"`` becomes
    ``"PROTEIN:7"``). Everything else is prefixed verbatim.

    Args:
        protein_list: Iterable of identifiers (strings or numbers).

    Returns:
        A new list of ``PROTEIN:``-prefixed strings, in input order.
    """
    def _normalize(item):
        if isinstance(item, str) and item.startswith('PROTEIN:'):
            return item
        try:
            return f'PROTEIN:{int(item)}'
        except (TypeError, ValueError, OverflowError):
            # Non-numeric ids (names, None, inf, ...) are prefixed as-is.
            return f'PROTEIN:{item}'

    return [_normalize(item) for item in protein_list]
    
def _truth_confidence(truth_value):
    """Parse the confidence ``c`` from a truth value rendered as ``"...%f;c%..."``.

    E.g. ``"<TruthValue: %1.00;0.99% (k=1)>"`` -> ``0.99``.
    """
    c_part = str(truth_value).split(";")[1]
    return float(c_part.split("%")[0])


def _predict_linked_nodes(proc, model, clf, candidates, candidate_vectors):
    """Return the *candidates* the link classifier predicts as linked to *proc*.

    Every (proc, candidate) pair is scored in a single ``clf.predict`` call.
    Each feature row is ``[embedding(proc), embedding(candidate)]``, and
    candidate order is preserved. ``candidate_vectors`` must hold the
    candidates' embeddings, one row per candidate, in the same order.
    """
    if not candidates or proc not in model:
        return []
    # Feature matrix has shape (len(candidates), 2 * embedding_dim).
    proc_rows = np.tile(model[proc], (len(candidates), 1))
    features = np.hstack((proc_rows, candidate_vectors))
    predictions = clf.predict(features)
    return [node for node, pred in zip(candidates, predictions) if pred == 1]


def procedure2gene(procedures):
    """Predict protein links for each code and run the reasoner over them.

    For each code in *procedures*, a random-forest link classifier
    (``rf_model.pkl``) scores the code's node2vec embedding
    (``node2vec_embeddings.txt``) paired with every other vocabulary node.
    Predicted links are normalized with ``process_protein_list`` and passed
    to ``Reasoner(5).reason``.

    Args:
        procedures: Iterable of node keys (the ICD-9 codes produced by the
            LLM step).

    Returns:
        ``(sorted_results, sorted_intermediates)``: the reasoner's results
        and intermediate results, each sorted by the confidence ``c`` parsed
        from the truth value, highest first. Both are empty when
        *procedures* is empty.
    """
    procedures = list(procedures)
    if not procedures:
        # Nothing to link. Returning early also skips loading the models.
        return [], []

    # === Load Pretrained Models ===
    model = KeyedVectors.load_word2vec_format("node2vec_embeddings.txt")
    clf = joblib.load("rf_model.pkl")

    # === Candidates: every vocabulary node except the procedures themselves ===
    procedure_set = set(procedures)
    candidates = [node for node in model.key_to_index if node not in procedure_set]
    # Stack candidate embeddings once; the matrix is reused for every procedure.
    candidate_vectors = (
        np.vstack([model[node] for node in candidates]) if candidates else None
    )

    # === Predict and normalize protein links per procedure ===
    protein_lists = [
        process_protein_list(
            _predict_linked_nodes(proc, model, clf, candidates, candidate_vectors)
        )
        for proc in procedures
    ]

    # === Reason over each procedure's protein list ===
    reasoner = Reasoner(5)
    all_results = []
    all_intermediates = []
    for protein_list in protein_lists:
        result, intermediate_results = reasoner.reason(protein_list)
        all_results.extend(result)
        all_intermediates.extend(intermediate_results)

    sorted_results = sorted(
        all_results,
        key=lambda x: _truth_confidence(x[1]),
        reverse=True
    )
    sorted_intermediates = sorted(
        all_intermediates,
        key=lambda x: _truth_confidence(x[1]),
        reverse=True
    )

    return sorted_results, sorted_intermediates


# Lazily created, process-wide extractor so the BioBERT model is loaded once
# rather than on every request.
_biobert_extractor = None


def _get_biobert_extractor():
    """Return the shared BioBertEntityExtractor, creating it on first use."""
    global _biobert_extractor
    if _biobert_extractor is None:
        _biobert_extractor = BioBertEntityExtractor()
    return _biobert_extractor


def get_deepseek_response(input_text):
    """Run the full pipeline on *input_text* and return the LLM explanation.

    The steps are: BioBERT entity extraction, LLM ICD-9 mapping, link
    prediction and reasoning (only for matched codes), then the
    patient-facing explanation. Progress messages are printed along the way.
    """
    entities = _get_biobert_extractor().extract_entities(input_text)
    print("Extracting...")
    output = extract_icd9_strict(input_text, entities)
    # Unknown diagnoses carry an empty code, which cannot be linked in the graph.
    icd9_codes = [
        entry.get("icd9_code")
        for entry in output.get("diagnoses", [])
        if entry.get("icd9_code")
    ]
    print("Reasoning...")
    if icd9_codes:
        sorted_results, sorted_intermediates = procedure2gene(icd9_codes)
    else:
        sorted_results = []
        sorted_intermediates = []
    print("Generating...")
    return get_explaination(sorted_results, sorted_intermediates)

if __name__ == "__main__":
    # Example usage: run each pipeline stage on a sample complaint. The guard
    # keeps importing this module from loading models or calling the API.
    sample_text = "I’ve been completely knocked out—so drowsy I can’t keep my eyes open, and my mind feels scrambled, like I’m trapped in a nightmare. My stomach is cramping so hard I’m doubled over, vomiting even bile. My heart is racing one minute and fluttering the next, and I’m gasping for air like I’m underwater. The scariest part? My skin blisters and turns fiery red within minutes of sunlight—like I’m allergic to the sun itself. My vision’s gone blurry, and my muscles are so weak I can’t even hold a glass. I took some old pills for sleep and an antibiotic last night... Now I feel like my body’s breaking down in every way possible."
    print("User Input:")
    print(sample_text)
    print('\n')
    biobert_extractor = BioBertEntityExtractor()
    entities = biobert_extractor.extract_entities(sample_text)
    output = extract_icd9_strict(sample_text, entities)
    # Only matched diagnoses carry a code that can be linked in the graph.
    icd9_codes = [
        entry.get("icd9_code")
        for entry in output.get("diagnoses", [])
        if entry.get("icd9_code")
    ]
    print(output)
    if icd9_codes:
        sorted_results, sorted_intermediates = procedure2gene(icd9_codes)
    else:
        sorted_results, sorted_intermediates = [], []

    res = get_explaination(sorted_results, sorted_intermediates)
    print("Diagnosis Output:")
    print(res)